
Remove additional weight prefetching all gathers #3412

Merged
copybara-service[bot] merged 7 commits into main from chengnuojin-pp-more
Apr 2, 2026

@NuojCheng
Collaborator

@NuojCheng NuojCheng commented Mar 13, 2026

Description

This PR does the following two things:

  • Enable a real buffered sliding window (BSW) for circular pipeline parallelism: we only need to gather one layer's weight and pass the old weight along as a scan carry, without incurring extra memory burden. See the description in my doc, go/maxtext-spmd-p-2026.
  • Add context parallelism (CP) support for custom mesh and logical rules, as well as for weight prefetching in circular pipeline parallelism.
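
The first bullet can be illustrated with a toy simulation in plain Python. This is a minimal sketch, not the MaxText implementation: `GatherCounter`, `run_naive`, and `run_sliding_window` are hypothetical names, the "all-gather" is just list concatenation with a call counter, and the naive baseline (gathering both the current and the next weight at every step) is an assumption about what the sliding window replaces.

```python
from typing import List


class GatherCounter:
    """Stand-in for an all-gather: concatenates shards and counts calls."""

    def __init__(self) -> None:
        self.calls = 0

    def all_gather(self, shards: List[List[float]]) -> List[float]:
        self.calls += 1
        return [v for shard in shards for v in shard]


def run_naive(layers, counter):
    """Gather both the current and the next layer's weight at every step."""
    used = []
    n = len(layers)
    for i in range(n):
        w_curr = counter.all_gather(layers[i])
        _w_next = counter.all_gather(layers[(i + 1) % n])  # prefetched, then dropped
        used.append(w_curr)
    return used


def run_sliding_window(layers, counter):
    """Carry the prefetched weight forward so each step gathers one layer."""
    used = []
    n = len(layers)
    w_curr = counter.all_gather(layers[0])  # prime the window once
    for i in range(n):
        w_next = counter.all_gather(layers[(i + 1) % n])  # the only gather this step
        used.append(w_curr)
        w_curr = w_next  # the carry: "next" becomes "current"
    return used


# Four layers, each weight split into two shards.
layers = [[[float(i), float(i) + 0.5], [float(i) + 1.0, float(i) + 1.5]] for i in range(4)]

naive_counter, window_counter = GatherCounter(), GatherCounter()
naive_used = run_naive(layers, naive_counter)
window_used = run_sliding_window(layers, window_counter)
```

Both variants consume the same sequence of gathered weights, but in this toy setup the naive loop issues two gathers per step while the sliding window issues one per step plus a single priming gather, and it never holds more than two gathered weights at a time.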

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 98.36066% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/pipeline_utils.py 97.87% 1 Missing ⚠️


@NuojCheng NuojCheng force-pushed the chengnuojin-pp-more branch from 2b3da04 to 680f1ad Compare March 17, 2026 23:53
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-more branch 4 times, most recently from 2d66c61 to 0c52610 Compare March 30, 2026 20:48
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-more branch 3 times, most recently from a1f6db3 to 8912464 Compare March 30, 2026 23:21
@NuojCheng NuojCheng added gemini-review and removed draft Draft PR labels Mar 30, 2026
@github-actions

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces significant refactoring to the circular pipeline parallelism implementation in MaxText, specifically targeting weight prefetching efficiency and communication overlap. The transition to a custom_vjp for pipeline stage execution is a sophisticated improvement that allows for better control over memory and communication during both forward and backward passes.

🔍 General Feedback

  • Significant Logic Refactoring: The transition to a sliding window of size 2 for block-sharded weights (w_curr, w_next) correctly addresses the pipeline delay and is much more efficient than fetching both current and next weights at every repeat.
  • Custom VJP Implementation: The custom_vjp in pipeline_utils.py correctly handles the linear transposition of the prefetching logic, ensuring that gradients for the pipeline weights are accumulated properly through the repeats.
  • Correctness Concern: A critical sharding mismatch was identified in gather_microbatch_inputs_vmap when ShardMode.EXPLICIT is used. This should be addressed before merging.
  • Config Improvement: The new pipeline-large-moe-cp.yml configuration correctly adapts logical axes for large-scale MoE jobs, reflecting DeepSeek-style model structures.
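
The custom VJP point above can be illustrated with a toy JAX example. This is a hedged sketch, not the code in pipeline_utils.py: `stage`, `stage_fwd`, `stage_bwd`, and the tanh layer are hypothetical, and it runs on a single device without any collectives. It only shows that when one weight is reused across several microbatches inside a scan, a hand-written backward rule yields the same accumulated weight gradient as JAX's automatic differentiation.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def stage(w, x):
    """One toy pipeline stage: a dense layer followed by tanh."""
    return jnp.tanh(x @ w)


def stage_fwd(w, x):
    y = jnp.tanh(x @ w)
    # Save what the backward rule needs.
    return y, (w, x, y)


def stage_bwd(residuals, g):
    w, x, y = residuals
    dz = g * (1.0 - y**2)      # derivative of tanh
    return x.T @ dz, dz @ w.T  # cotangents for (w, x)


stage.defvjp(stage_fwd, stage_bwd)


def loss_custom(w, microbatches):
    # The same weight is reused for every microbatch; JAX sums the
    # per-microbatch weight cotangents returned by stage_bwd.
    def body(total, x):
        return total + jnp.sum(stage(w, x)), None

    total, _ = jax.lax.scan(body, jnp.zeros((), jnp.float32), microbatches)
    return total


def loss_reference(w, microbatches):
    # Same computation with no custom rule, for comparison.
    return jnp.sum(jnp.tanh(microbatches @ w))


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (4, 3))
microbatches = jax.random.normal(key_x, (8, 2, 4))  # 8 microbatches of shape (2, 4)

grad_custom = jax.grad(loss_custom)(w, microbatches)
grad_reference = jax.grad(loss_reference)(w, microbatches)
```

Comparing `grad_custom` with `grad_reference` (for example with `jnp.allclose`) is the same style of check as comparing pipeline output and gradients against a non-pipelined baseline, only on a much smaller scale.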

Comment thread src/maxtext/layers/pipeline.py
Comment thread src/maxtext/layers/pipeline.py Outdated
Comment thread src/maxtext/utils/pipeline_utils.py Outdated
Comment thread src/maxtext/utils/pipeline_utils.py Outdated
Collaborator

@gobbleturk gobbleturk left a comment


This is awesome!

Is there an existing test that protects the correctness of this? If not we definitely should add one in this PR

@NuojCheng NuojCheng force-pushed the chengnuojin-pp-more branch from efda1e8 to 60dbacd Compare March 31, 2026 03:36
@NuojCheng
Collaborator Author

> This is awesome!
>
> Is there an existing test that protects the correctness of this? If not we definitely should add one in this PR

The correctness of the loss and gradients is protected by this test:

@pytest.mark.tpu_only
def test_circular_pipeline_ag_per_repeat(self):
  # 2 stages, 8 microbatches, enable pipeline ag per repeat
  config = pyconfig.initialize(
      [sys.argv[0], get_test_config_path()],
      enable_checkpointing=False,
      enable_goodput_recording=False,
      run_name="circular_ag_per_repeat",
      max_target_length=128,
      base_emb_dim=28,
      ici_pipeline_parallelism=2,
      base_num_decoder_layers=8,
      num_pipeline_microbatches=8,
      per_device_batch_size=4,
      pipeline_fsdp_ag_per_repeat=True,
  )
  self.assert_pipeline_same_output_and_grad(config)

@NuojCheng NuojCheng force-pushed the chengnuojin-pp-more branch from 69c99bd to 049ba3f Compare April 2, 2026 18:29
@copybara-service copybara-service Bot merged commit c38fa86 into main Apr 2, 2026
24 of 26 checks passed
@copybara-service copybara-service Bot deleted the chengnuojin-pp-more branch April 2, 2026 19:31